{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04c6078a-0398-425c-82f7-d516b01b713d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import asyncio\n",
    "import panel as pn\n",
    "import param\n",
    "\n",
    "from panel.custom import JSComponent, ESMEvent\n",
    "\n",
    "pn.extension('mathjax', template='material')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9aea88fd-3ae9-453f-a66e-b3bdc44c22e3",
   "metadata": {},
   "source": [
    "This example demonstrates how to wrap an external library (specifically [WebLLM](https://github.com/mlc-ai/web-llm)) as a `JSComponent` and connect it to a `ChatInterface`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "957c9fab-3fa7-48d7-83d0-5532bde6e547",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display label (with approximate download size) -> WebLLM model id.\n",
    "# Listed from smallest to largest; the dict keeps this order.\n",
    "MODELS = dict([\n",
    "    ('SmolLM2 (130MB)', 'SmolLM2-360M-Instruct-q4f16_1-MLC'),\n",
    "    ('TinyLlama-1.1B-Chat (675 MB)', 'TinyLlama-1.1B-Chat-v1.0-q4f16_1-MLC-1k'),\n",
    "    ('Gemma-2b (2GB)', 'gemma-2-2b-it-q4f16_1-MLC'),\n",
    "    ('Llama-3.2-3B-Instruct (2.2GB)', 'Llama-3.2-3B-Instruct-q4f16_1-MLC'),\n",
    "    ('Mistral-7b-Instruct (5GB)', 'Mistral-7B-Instruct-v0.3-q4f16_1-MLC'),\n",
    "])\n",
    "\n",
    "class WebLLM(JSComponent):\n",
    "    \"\"\"\n",
    "    Wrap the WebLLM browser library as a Panel `JSComponent`.\n",
    "\n",
    "    The model runs entirely in the browser. Python drives it by sending custom\n",
    "    messages (`{'type': 'load'}` and `{'type': 'completion', ...}`) and receives\n",
    "    streamed completion choices back through `_handle_msg`.\n",
    "    \"\"\"\n",
    "\n",
    "    loaded = param.Boolean(default=False, doc=\"\"\"\n",
    "        Whether the model is loaded.\"\"\")\n",
    "\n",
    "    # NOTE(review): not used yet -- `callback` only sends the latest user message.\n",
    "    history = param.Integer(default=3)\n",
    "\n",
    "    # Latest init-progress report from WebLLM; its 'text' and 'progress' keys are displayed.\n",
    "    status = param.Dict(default={'text': '', 'progress': 0})\n",
    "\n",
    "    # Triggered by the load button built in `menu`.\n",
    "    load_model = param.Event()\n",
    "\n",
    "    model = param.Selector(default='SmolLM2-360M-Instruct-q4f16_1-MLC', objects=MODELS)\n",
    "\n",
    "    running = param.Boolean(default=False, doc=\"\"\"\n",
    "        Whether the LLM is currently running.\"\"\")\n",
    "\n",
    "    temperature = param.Number(default=1, bounds=(0, 2), doc=\"\"\"\n",
    "        Temperature of the model completions.\"\"\")\n",
    "\n",
    "    _esm = \"\"\"\n",
    "    import * as webllm from \"https://esm.run/@mlc-ai/web-llm\";\n",
    "\n",
    "    // One engine per model id, so each model is only created once per page.\n",
    "    const engines = new Map()\n",
    "\n",
    "    export async function render({ model }) {\n",
    "      model.on(\"msg:custom\", async (event) => {\n",
    "        if (event.type === 'load') {\n",
    "          // Capture the id up front: the user may switch models while this awaits.\n",
    "          const name = model.model\n",
    "          if (!engines.has(name)) {\n",
    "            const initProgressCallback = (status) => {\n",
    "              model.status = status\n",
    "            }\n",
    "            const mlc = await webllm.CreateMLCEngine(\n",
    "               name,\n",
    "               {initProgressCallback}\n",
    "            )\n",
    "            engines.set(name, mlc)\n",
    "          }\n",
    "          // Only report success if this model is still the selected one.\n",
    "          if (model.model === name) {\n",
    "            model.loaded = true\n",
    "          }\n",
    "        } else if (event.type === 'completion') {\n",
    "          const engine = engines.get(model.model)\n",
    "          if (engine == null) {\n",
    "            model.send_msg({'finish_reason': 'error'})\n",
    "            return\n",
    "          }\n",
    "          const chunks = await engine.chat.completions.create({\n",
    "            messages: event.messages,\n",
    "            temperature: model.temperature,\n",
    "            stream: true,\n",
    "          })\n",
    "          model.running = true\n",
    "          for await (const chunk of chunks) {\n",
    "            // Python sets running = false to cancel an in-flight generation.\n",
    "            if (!model.running) {\n",
    "              break\n",
    "            }\n",
    "            model.send_msg(chunk.choices[0])\n",
    "          }\n",
    "        }\n",
    "      })\n",
    "    }\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, **params):\n",
    "        super().__init__(**params)\n",
    "        # When served, mirror the selected model in the URL query string.\n",
    "        if pn.state.location:\n",
    "            pn.state.location.sync(self, {'model': 'model'})\n",
    "        # FIFO of completion choices received from the browser (newest at index 0).\n",
    "        self._buffer = []\n",
    "\n",
    "    @param.depends('load_model', watch=True)\n",
    "    def _load_model(self):\n",
    "        \"\"\"Ask the browser to create (or reuse) the engine for `model`.\"\"\"\n",
    "        self.loading = True\n",
    "        self._send_msg({'type': 'load'})\n",
    "\n",
    "    @param.depends('loaded', watch=True)\n",
    "    def _loaded_model(self):\n",
    "        \"\"\"Clear the loading indicator whenever `loaded` changes.\"\"\"\n",
    "        self.loading = False\n",
    "\n",
    "    @param.depends('model', watch=True)\n",
    "    def _update_load_model(self):\n",
    "        \"\"\"Reset the load state when a different model is selected.\"\"\"\n",
    "        self.loaded = False\n",
    "        # The browser only reports `loaded` for the model that is still selected,\n",
    "        # so a pending load of the previous model must not keep us in `loading`.\n",
    "        self.loading = False\n",
    "\n",
    "    def _handle_msg(self, msg):\n",
    "        \"\"\"Queue a completion choice streamed from the browser.\"\"\"\n",
    "        # Chunks are only accepted while a generation runs, so stale chunks from a\n",
    "        # cancelled stream are dropped. The browser sends the 'error' reply before\n",
    "        # it marks a run as started, so that reply must always be accepted.\n",
    "        is_error = isinstance(msg, dict) and msg.get('finish_reason') == 'error'\n",
    "        if self.running or is_error:\n",
    "            self._buffer.insert(0, msg)\n",
    "\n",
    "    async def create_completion(self, msgs):\n",
    "        \"\"\"\n",
    "        Stream completion choices for `msgs` from the in-browser engine.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        msgs : list of dict\n",
    "            Chat messages, e.g. ``[{'role': 'user', 'content': ...}]``.\n",
    "\n",
    "        Yields\n",
    "        ------\n",
    "        dict\n",
    "            One streamed choice per chunk; the stream ends after the first\n",
    "            choice with a truthy ``'finish_reason'``.\n",
    "\n",
    "        Raises\n",
    "        ------\n",
    "        RuntimeError\n",
    "            If the browser has no engine loaded for `model`.\n",
    "        \"\"\"\n",
    "        self._send_msg({'type': 'completion', 'messages': msgs})\n",
    "        while True:\n",
    "            # Poll the buffer that `_handle_msg` fills.\n",
    "            await asyncio.sleep(0.01)\n",
    "            if not self._buffer:\n",
    "                continue\n",
    "            choice = self._buffer.pop()\n",
    "            reason = choice.get('finish_reason')\n",
    "            # Raise before yielding: the error reply carries no 'delta'.\n",
    "            if reason == 'error':\n",
    "                raise RuntimeError('Model not loaded')\n",
    "            yield choice\n",
    "            if reason:\n",
    "                return\n",
    "\n",
    "    async def callback(self, contents: str, user: str):\n",
    "        \"\"\"\n",
    "        `ChatInterface` callback that streams the model's reply to `contents`.\n",
    "\n",
    "        If no model is loaded yet, a single status message is yielded instead:\n",
    "        live loading progress while loading, otherwise a prompt to load a model.\n",
    "        \"\"\"\n",
    "        if not self.loaded:\n",
    "            if self.loading:\n",
    "                yield pn.pane.Markdown(\n",
    "                    f'## `{self.model}`\\n\\n' + self.param.status.rx()['text']\n",
    "                )\n",
    "            else:\n",
    "                yield 'Load the model'\n",
    "            return\n",
    "        # Cancel any generation still streaming and discard its leftover chunks.\n",
    "        self.running = False\n",
    "        self._buffer.clear()\n",
    "        message = \"\"\n",
    "        async for chunk in self.create_completion([{'role': 'user', 'content': contents}]):\n",
    "            # Treat a missing or null delta content as empty text.\n",
    "            message += chunk['delta'].get('content') or ''\n",
    "            yield message\n",
    "\n",
    "    def menu(self):\n",
    "        \"\"\"Build the controls: model and temperature pickers, load button, progress.\"\"\"\n",
    "        status = self.param.status.rx()\n",
    "        return pn.Column(\n",
    "            pn.widgets.Select.from_param(self.param.model, sizing_mode='stretch_width'),\n",
    "            pn.widgets.FloatSlider.from_param(self.param.temperature, sizing_mode='stretch_width'),\n",
    "            pn.widgets.Button.from_param(\n",
    "                self.param.load_model, sizing_mode='stretch_width',\n",
    "                disabled=self.param.loaded.rx().rx.or_(self.param.loading)\n",
    "            ),\n",
    "            # Progress bar and status text are only shown while loading.\n",
    "            pn.indicators.Progress(\n",
    "                value=(status['progress']*100).rx.pipe(int), visible=self.param.loading,\n",
    "                sizing_mode='stretch_width'\n",
    "            ),\n",
    "            pn.pane.Markdown(status['text'], visible=self.param.loading)\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a663e937-797b-468f-875d-5bb8c2af002b",
   "metadata": {},
   "source": [
    "Having implemented the `WebLLM` component, we can render its controls in the sidebar:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58269444-868b-41e4-abe2-c4fcf031dc4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = WebLLM()\n",
    "\n",
    "# `strip` removes the leading/trailing newline so the joined text has no stray spaces.\n",
    "intro = pn.pane.Alert(\"\"\"\n",
    "`WebLLM` runs large language models entirely in your browser.\n",
    "When visiting the application for the first time, the model has\n",
    "to be downloaded and loaded into memory, which may take\n",
    "some time. Models are ordered by size (and capability),\n",
    "e.g. SmolLM2 is very quick to download but produces poor\n",
    "quality output, while Mistral-7b will take a while to\n",
    "download but produces much higher quality output.\n",
    "\"\"\".strip().replace('\\n', ' '))\n",
    "\n",
    "# `llm` is included so its ESM `render` hook runs in the browser (it draws nothing itself).\n",
    "pn.Column(\n",
    "    llm.menu(),\n",
    "    intro,\n",
    "    llm\n",
    ").servable(area='sidebar')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "96229aa4-c5ed-4c4e-944a-789ee65d768f",
   "metadata": {},
   "source": [
    "And connect it to a `ChatInterface`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f899068-8975-4cf4-9e1d-f3fdb5772a71",
   "metadata": {},
   "outputs": [],
   "source": [
    "chat_interface = pn.chat.ChatInterface(callback=llm.callback)\n",
    "chat_interface.send(\n",
    "    \"Load a model and start chatting.\",\n",
    "    user=\"System\",\n",
    "    respond=False,\n",
    ")\n",
    "\n",
    "def announce_loaded(event):\n",
    "    \"\"\"Post a system message when the selected model finishes loading.\"\"\"\n",
    "    # `loaded` also flips back to False whenever a different model is selected;\n",
    "    # only the transition to True should be announced.\n",
    "    if event.new:\n",
    "        chat_interface.send(f'Loaded `{event.obj.model}`, start chatting!', user='System', respond=False)\n",
    "\n",
    "llm.param.watch(announce_loaded, 'loaded')\n",
    "\n",
    "pn.Row(chat_interface).servable(title='WebLLM')"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
